# basics
import os, sys
import pickle
import random
import numpy as np
from PIL import Image

# torch modules
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset
from torchvision.utils import save_image

# torchvision modules
import torchvision
from torchvision.transforms import transforms
from torchvision import datasets
import matplotlib.pyplot as plt

# custom libs
sys.path.append('..')
from utils.io import load_from_csvfile, load_from_h5file
import os
from torchvision.datasets.folder import default_loader


# ------------------------------------------------------------------------------
#   Dataset loader
# ------------------------------------------------------------------------------
def load_dataset(dataset, datapth, batchsize, augment, kwargs, normalize=True):
    """Build the train/validation DataLoaders for a named dataset.

    Args:
        dataset: dataset name; one of 'mnist', 'cifar10' or 'imagenet'.
        datapth: accepted for interface compatibility only; it is not
            forwarded to any of the dataset-specific loaders.
        batchsize: batch size handed to the dataset-specific loader.
        augment: augmentation name handed to the dataset-specific loader.
        kwargs: dict of extra keyword arguments handed to the loader.
        normalize: forwarded unchanged to the dataset-specific loader.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.

    Raises:
        AssertionError: if ``dataset`` is not one of the supported names.
    """
    # MNIST
    if 'mnist' == dataset:
        return load_mnist(batchsize, augment, kwargs, normalize=normalize)

    # CIFAR10
    if 'cifar10' == dataset:
        return load_cifar10(batchsize, augment, kwargs, normalize=normalize)

    # ImageNet
    if 'imagenet' == dataset:
        return load_imagenet(batchsize, augment, kwargs, normalize=normalize)

    # Undefined dataset: raised explicitly (rather than `assert False`) so the
    # check survives `python -O`; the exception type and message are the same
    # as before, so existing callers are unaffected.
    raise AssertionError('Error: invalid dataset name [{}]'.format(dataset))


# ------------------------------------------------------------------------------
#   Actual Loaders
# ------------------------------------------------------------------------------
# Location of the MNIST files and the (mean, std) pair passed to Normalize.
_MNIST_ROOT = 'datasets/originals/mnist'
_MNIST_MEAN = (0.1307,)
_MNIST_STD = (0.3081,)


def _mnist_train_augmentation(augment):
    """Return the extra training transforms selected by `augment`."""
    if augment == '':
        return []
    if augment == 'rotation':
        return [transforms.RandomRotation(30)]
    if augment == 'hflip':
        return [transforms.RandomHorizontalFlip(p=0.25)]
    raise ValueError(
        "Error: invalid augment [{}] for mnist; expected '', 'rotation' "
        "or 'hflip'".format(augment))


def _mnist_transform(extra, normalize):
    """Compose ToTensor -> `extra` transforms -> optional Normalize."""
    steps = [transforms.ToTensor()]
    steps.extend(extra)
    if normalize:
        steps.append(transforms.Normalize(_MNIST_MEAN, _MNIST_STD))
    return transforms.Compose(steps)


def _mnist_loader(train, transform, batchsize, kwargs):
    """Wrap one MNIST split in a DataLoader (only the train split is shuffled)."""
    split = torchvision.datasets.MNIST(root=_MNIST_ROOT,
                                       train=train, download=True,
                                       transform=transform)
    return torch.utils.data.DataLoader(
        split, batch_size=batchsize, shuffle=train, **kwargs)


def load_mnist(batchsize, augment, kwargs, normalize=True):
    """Create the MNIST train and validation DataLoaders.

    The data is read from (and downloaded to, if missing)
    'datasets/originals/mnist'.

    Args:
        batchsize: batch size used by both loaders.
        augment: training-time augmentation, applied after ToTensor and
            before Normalize: '' (none), 'rotation' (RandomRotation(30)) or
            'hflip' (RandomHorizontalFlip(p=0.25)). Only consulted when
            `normalize` is true; the validation split is never augmented.
        kwargs: dict of extra keyword arguments passed to both DataLoaders.
        normalize: when true, Normalize((0.1307,), (0.3081,)) is appended to
            both pipelines. When false, both pipelines are ToTensor only and
            `augment` is ignored.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.

    Raises:
        ValueError: if `normalize` is true and `augment` is not a supported
            name (previously this surfaced as an UnboundLocalError).
    """
    # Validate `augment` first so a bad name fails before any download starts.
    # Augmentation is only honoured together with normalization.
    extra = _mnist_train_augmentation(augment) if normalize else []

    train_loader = _mnist_loader(
        True, _mnist_transform(extra, normalize), batchsize, kwargs)
    valid_loader = _mnist_loader(
        False, _mnist_transform([], normalize), batchsize, kwargs)

    return train_loader, valid_loader


# Location of the CIFAR10 files and the (mean, std) pair passed to Normalize.
_CIFAR10_ROOT = 'datasets/originals/cifar10'
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def _cifar10_train_augmentation(augment):
    """Return the extra training transforms selected by `augment`."""
    if augment == '':
        # The default setting still applies a padded random crop and a flip.
        return [transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip()]
    if augment == 'rotation':
        return [transforms.RandomRotation(30)]
    if augment == 'hflip':
        return [transforms.RandomHorizontalFlip(p=0.25)]
    raise ValueError(
        "Error: invalid augment [{}] for cifar10; expected '', 'rotation' "
        "or 'hflip'".format(augment))


def _cifar10_transform(extra, normalize):
    """Compose ToTensor -> `extra` transforms -> optional Normalize."""
    steps = [transforms.ToTensor()]
    steps.extend(extra)
    if normalize:
        steps.append(transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD))
    return transforms.Compose(steps)


def _cifar10_loader(train, transform, batchsize, kwargs):
    """Wrap one CIFAR10 split in a DataLoader (only the train split is shuffled)."""
    split = torchvision.datasets.CIFAR10(root=_CIFAR10_ROOT,
                                         train=train, download=True,
                                         transform=transform)
    return torch.utils.data.DataLoader(
        split, batch_size=batchsize, shuffle=train, **kwargs)


def load_cifar10(batchsize, augment, kwargs, normalize=True):
    """Create the CIFAR10 train and validation DataLoaders.

    The data is read from (and downloaded to, if missing)
    'datasets/originals/cifar10'.

    Args:
        batchsize: batch size used by both loaders.
        augment: training-time augmentation, applied after ToTensor and
            before Normalize: '' (RandomCrop(32, padding=4) followed by
            RandomHorizontalFlip()), 'rotation' (RandomRotation(30)) or
            'hflip' (RandomHorizontalFlip(p=0.25)). Only consulted when
            `normalize` is true; the validation split is never augmented.
        kwargs: dict of extra keyword arguments passed to both DataLoaders.
            It is now applied on every path; some branches used to drop it.
        normalize: when true, Normalize with the CIFAR10 mean/std constants
            is appended to both pipelines. When false, both pipelines are
            ToTensor only and `augment` is ignored.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.

    Raises:
        ValueError: if `normalize` is true and `augment` is not a supported
            name (previously this surfaced as an UnboundLocalError).
    """
    # Validate `augment` first so a bad name fails before any download starts.
    # Augmentation is only honoured together with normalization.
    extra = _cifar10_train_augmentation(augment) if normalize else []

    train_loader = _cifar10_loader(
        True, _cifar10_transform(extra, normalize), batchsize, kwargs)
    valid_loader = _cifar10_loader(
        False, _cifar10_transform([], normalize), batchsize, kwargs)

    return train_loader, valid_loader

def custom_loader(path):
    """Load the image at `path` with torchvision's default loader.

    Returns:
        Whatever ``datasets.folder.default_loader`` returns for `path`, or
        None when the file does not exist.
    """
    image = None
    try:
        image = datasets.folder.default_loader(path)
    except FileNotFoundError:
        # A missing file is reported to the caller as None, not as an error.
        pass
    return image
    
class CustomTransform:
    """Callable wrapper that applies `transform` but lets None pass through."""

    def __init__(self, transform):
        # The wrapped callable; invoked for every non-None input.
        self.transform = transform

    def __call__(self, img):
        """Return ``self.transform(img)``, or None when `img` is None."""
        return None if img is None else self.transform(img)

def load_imagenet(batchsize, augment, kwargs, normalize=True,
                  datapth="path to dataset"):
    """Create the ImageNet train and validation DataLoaders.

    Both splits are read with ``datasets.ImageFolder`` from the 'train' and
    'val' sub-directories of `datapth`.

    Args:
        batchsize: batch size used by both loaders (previously a hard-coded
            32 was used and this argument was ignored).
        augment: accepted for signature parity with the other loaders; it is
            not used, the training pipeline is always RandomResizedCrop(224)
            followed by RandomHorizontalFlip().
        kwargs: dict of extra keyword arguments passed to both DataLoaders
            (may be None or empty). ``num_workers`` defaults to 1 when it is
            not supplied here.
        normalize: when true, Normalize with the ImageNet mean/std constants
            is appended to both pipelines; when false it is omitted
            (previously normalize=False raised an UnboundLocalError).
        datapth: root directory of the dataset. The default is the original
            placeholder string and must be overridden to load real data.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.
    """
    print("Loading imagenet data")

    train_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
    valid_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if normalize:
        for steps in (train_steps, valid_steps):
            steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225]))

    # Each step is wrapped so that a None image passes through the pipeline
    # untouched instead of raising inside a transform.
    train_transform = transforms.Compose(
        [CustomTransform(step) for step in train_steps])
    valid_transform = transforms.Compose(
        [CustomTransform(step) for step in valid_steps])

    # Keep the historical single-worker default unless the caller overrides it.
    loader_kwargs = {'num_workers': 1}
    loader_kwargs.update(kwargs or {})

    train_dataset = datasets.ImageFolder(os.path.join(datapth, "train"),
                                         transform=train_transform)
    train_loader = DataLoader(train_dataset, batch_size=batchsize,
                              shuffle=True, **loader_kwargs)
    val_dataset = datasets.ImageFolder(os.path.join(datapth, "val"),
                                       transform=valid_transform)
    valid_loader = DataLoader(val_dataset, batch_size=batchsize,
                              shuffle=False, **loader_kwargs)

    return train_loader, valid_loader
